import _pickle as cPickle
import numpy as np
from scipy.io.wavfile import read
from sklearn.mixture import GaussianMixture as GMM 
from speakerfeatures import extract_features
import warnings
import os,sys
warnings.filterwarnings("ignore")

def _collect_speaker_features(people_path):
    """Stack the feature matrices of every audio file in one speaker folder.

    Args:
        people_path: Directory holding the recordings of a single speaker.

    Returns:
        The row-wise concatenation of every non-empty result returned by
        ``extract_features``, or ``None`` when no file produced any rows.
    """
    chunks = []
    # Sorted so the processing order does not depend on the filesystem.
    for wav in sorted(os.listdir(people_path)):
        wav_path = os.path.join(people_path, wav)
        if not os.path.isfile(wav_path):
            # Nested folders cannot be decoded as audio; ignore them.
            continue
        sr, audio = read(wav_path)
        # NOTE(review): extract_features is a project helper; the original
        # comment says it yields 40-dimensional MFCC + delta MFCC features --
        # confirm against speakerfeatures.
        vector = extract_features(audio, sr)
        if np.size(vector):
            chunks.append(vector)
    if not chunks:
        return None
    # One concatenation at the end avoids re-copying the growing matrix
    # after every file.
    return np.vstack(chunks)


def train_gmm2(source, dest):
    """Train one Gaussian mixture model per speaker folder and pickle it.

    Every sub-directory of ``source`` is treated as one speaker. All audio
    files inside it are read, converted to features and stacked, then a
    16-component full-covariance GMM is fitted and dumped to
    ``<dest>/<speaker>.gmm``. Speakers whose model file already exists are
    skipped, so the function can be re-run to resume training.

    Args:
        source: Directory containing one sub-directory per speaker.
        dest: Directory where the ``.gmm`` pickle files are written. It is
            created if it does not exist yet.

    Returns:
        None. Progress is reported on stdout.
    """
    print(source, dest)
    for path in sorted(os.listdir(source)):
        people_path = os.path.join(source, path)
        if not os.path.isdir(people_path):
            # Stray files next to the speaker folders are not speakers.
            continue

        picklefile = path + ".gmm"
        model_path = os.path.join(dest, picklefile)
        if os.path.exists(model_path):
            print(path, 'gmm model had been created!')
            continue

        print('Begin to train gmm model for ' , path)
        features = _collect_speaker_features(people_path)
        if features is None:
            # Nothing to fit; fitting an empty matrix would only raise.
            print('No usable audio found for', path)
            continue

        # Make sure the output folder exists before spending time on the
        # fit. An empty dest means "current directory" and needs no creation.
        if dest:
            os.makedirs(dest, exist_ok=True)

        gmm = GMM(n_components=16, max_iter=400, covariance_type='full', n_init=4)
        gmm.fit(features)

        # Dump the trained model; the context manager guarantees the file is
        # flushed and closed even if pickling fails.
        with open(model_path, 'wb') as model_file:
            cPickle.dump(gmm, model_file)
        print('+ modeling completed for speaker:',picklefile," with data point = ",features.shape)

if __name__ == '__main__':
    # Command line: <script> <source_audio_dir> <model_output_dir>
    # Fail with a readable usage message instead of an IndexError traceback
    # when either positional argument is missing.
    if len(sys.argv) < 3:
        sys.exit("usage: {} <source_audio_dir> <model_output_dir>".format(sys.argv[0]))

    # Path to the training data (input audio directory),
    # e.g. "audio/audio_wav3".
    source = sys.argv[1]
    # Path where the trained speaker models will be saved,
    # e.g. "gmm2_models_3".
    dest = sys.argv[2]
    train_gmm2(source, dest)
